import os
import torch
import numpy as np
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM

# Runtime configuration, overridable through environment variables.
# expanduser(): os.path.exists() / torch.load() do not expand "~" themselves, so
# the default "~/..." checkpoint path would otherwise never be found.
MODEL_PATH = os.path.expanduser(
    os.environ.get('MODEL_PATH', '~/dual-map/model/dual_map_mlp_model.pt')
)
print(f"MODEL_PATH: {MODEL_PATH}")
GEMMA_MODEL_NAME = os.environ.get('GEMMA_MODEL_NAME', "google/gemma-2-2b")
print(f"GEMMA_MODEL_NAME: {GEMMA_MODEL_NAME}")

PROMPT = os.environ.get('PROMPT', "The capital of France is")
print(f"Testing prompt is: {PROMPT}")

########################################################
# Define MLP here
########################################################

class MLP(nn.Module):
    """Four-layer GELU feed-forward network.

    Hidden layers are ``2 * hidden_dim`` wide. The input is divided by
    ``sqrt(input_dim)`` before the first linear layer. The module layout
    (``layers.0``, ``layers.2``, ``layers.4``, ``layers.6`` are the linear
    layers) determines the state_dict keys, so it must match the checkpoint
    that is loaded into it.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=1024):
        super().__init__()
        self.input_dim = input_dim
        width = 2 * hidden_dim
        sizes = [input_dim, width, width, width, output_dim]

        # Linear layers interleaved with GELU; no activation after the last one.
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            if idx:
                modules.append(nn.GELU())
            modules.append(nn.Linear(fan_in, fan_out))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Scale ``x`` by 1/sqrt(input_dim) and run it through the layer stack."""
        scale = np.sqrt(self.input_dim)
        return self.layers(x / scale)

########################################################
# Predictor
########################################################
# Combines a Gemma causal LM with the trained MLP defined above (the MLP architecture must match the one used during training).
class GemmaEmbeddingPredictor:
    """Pair a Gemma causal LM with a trained MLP that maps the final-layer hidden
    state of a prompt's last token to a predicted next-token embedding.

    At construction time the LM's input-embedding matrix G is centered and
    decomposed with an SVD. ``self.g`` holds U @ Vt, the centered embeddings with
    every singular value set to 1. ``self.whitening_matrix`` and ``self.g_mean``
    are used to map the LM's expected next-token embedding into the space the
    MLP output is compared against.
    """

    def __init__(self, mlp_model_path=MODEL_PATH, gemma_model_name=GEMMA_MODEL_NAME):
        """
        Initialize the embedding predictor by loading both the Gemma model and the trained MLP

        Args:
            mlp_model_path: Path to the trained MLP model weights ("~" is expanded)
            gemma_model_name: Name or path of the Gemma model
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load Gemma model and tokenizer
        print(f"Loading Gemma model and tokenizer: {gemma_model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(gemma_model_name)

        # No torch_dtype is passed, so weights load in transformers' default
        # dtype; hidden states and embeddings are cast to float32 where used.
        self.gemma_model = AutoModelForCausalLM.from_pretrained(
            gemma_model_name,
            device_map="auto",
            output_hidden_states=True,
            # Expand "~" explicitly rather than relying on the loader to do it.
            cache_dir=os.path.expanduser("~/gemma_cache"),
        )

        # The MLP maps hidden_size -> hidden_size
        self.input_dim = self.gemma_model.config.hidden_size
        self.output_dim = self.input_dim

        # Load the MLP model
        print(f"Loading MLP model from {mlp_model_path}")
        self.mlp_model = self._load_mlp_model(mlp_model_path)
        self.mlp_model.eval()  # Set to evaluation mode

        original_g = self.gemma_model.get_input_embeddings().weight
        self.original_g = original_g

        # The embedding matrix is a trainable Parameter: without no_grad the
        # centering/SVD below would record an autograd graph over the whole
        # vocabulary matrix and keep it alive through self.g.
        with torch.no_grad():
            # float() is a no-op for float32 weights and lets the SVD run if
            # the weights were loaded in half precision.
            g_float = original_g.float()
            g_mean = g_float.mean(dim=0)
            u, s, vt = torch.linalg.svd(g_float - g_mean, full_matrices=False)

            # U @ Vt == (G - mean) @ V diag(1/s) Vt
            self.g = u @ vt

            self.g_mean = g_mean.cpu()
            # NOTE(review): this scales by 1/sqrt(s), while self.g corresponds
            # to scaling by 1/s, so vectors mapped with this matrix are not in
            # exactly the same space as the rows of self.g. Left unchanged
            # because the MLP's training targets presumably used the same
            # transform -- confirm against the training code.
            self.whitening_matrix = (
                vt.T @ torch.diag(1.0 / torch.sqrt(s + 1e-6)) @ vt
            ).cpu()

    def _load_mlp_model(self, model_path):
        """
        Load the trained MLP model

        Args:
            model_path: Path to the saved state_dict ("~" is expanded)

        Returns:
            MLP on self.device in float32 with the saved weights loaded

        Raises:
            FileNotFoundError: if the (expanded) path does not exist
        """
        # os.path.exists does not expand "~", so "~/..." paths would never match.
        model_path = os.path.expanduser(model_path)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Create model with the correct dimensions
        model = MLP(self.input_dim, self.output_dim).to(self.device).float()

        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        return model

    def _run_gemma(self, prompt):
        """Run Gemma on *prompt* without gradients.

        Returns:
            (model outputs, float32 final-layer hidden state of the last token
            of the first sequence, moved to self.device)
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.gemma_model(**inputs, output_hidden_states=True)
            last_hidden_states = outputs.hidden_states[-1]  # last layer
            last_token_embedding = last_hidden_states[0, -1].float().to(self.device)
        return outputs, last_token_embedding

    def get_last_token_embedding(self, prompt):
        """
        Get the last token embedding from Gemma model for a prompt

        Args:
            prompt: Input text prompt

        Returns:
            float32 final-layer hidden state of the prompt's last token
        """
        _, last_token_embedding = self._run_gemma(prompt)
        return last_token_embedding

    def predict_next_token_embedding(self, prompt):
        """
        Predict the expected next token embedding using the trained MLP

        Args:
            prompt: Input text prompt

        Returns:
            Predicted expected embedding of the next token (on self.device)
        """
        last_token_embedding = self.get_last_token_embedding(prompt)
        with torch.no_grad():
            return self.mlp_model(last_token_embedding)

    def compare_with_actual(self, prompt):
        """
        Compare the MLP-predicted embedding with the actual expected embedding

        The "actual" embedding is the probability-weighted average of Gemma's
        input embeddings under its next-token distribution, centered, multiplied
        by self.whitening_matrix and scaled by sqrt(vocab_size / hidden_size).

        Args:
            prompt: Input text prompt

        Returns:
            Dictionary with keys "predicted_embedding" and "actual_embedding"
            (numpy arrays), "cosine_similarity" and "euclidean_distance" (floats)
        """
        outputs, last_token_embedding = self._run_gemma(prompt)

        with torch.no_grad():
            # Next-token distribution; softmax in float32 for numerical stability
            token_probs = torch.softmax(outputs.logits[0, -1].float(), dim=0)

            # Expected embedding under that distribution
            word_embeddings = self.gemma_model.get_input_embeddings().weight
            actual_embedding = torch.matmul(
                token_probs.to(word_embeddings.device), word_embeddings.float()
            ).cpu()
            vocab_size, hidden_size = self.original_g.shape
            actual_embedding = (
                (actual_embedding - self.g_mean) @ self.whitening_matrix
                * np.sqrt(vocab_size / hidden_size)
            )

            # Predict with MLP
            predicted_embedding = self.mlp_model(last_token_embedding).cpu()

            print(last_token_embedding, torch.norm(last_token_embedding))
            print(actual_embedding, torch.norm(actual_embedding))
            print(predicted_embedding, torch.norm(predicted_embedding))

            cos_sim = torch.nn.functional.cosine_similarity(
                predicted_embedding.unsqueeze(0),
                actual_embedding.unsqueeze(0),
            ).item()
            euclidean_distance = torch.norm(predicted_embedding - actual_embedding).item()

        return {
            "predicted_embedding": predicted_embedding.numpy(),
            "actual_embedding": actual_embedding.numpy(),
            "cosine_similarity": cos_sim,
            "euclidean_distance": euclidean_distance,
        }

    def _as_g_tensor(self, embedding):
        """Convert *embedding* (tensor or array-like) to self.g's device and dtype."""
        return torch.as_tensor(embedding, dtype=self.g.dtype, device=self.g.device)

    def _top_k_tokens(self, scores, top_k):
        """Return the top_k highest-scoring vocabulary entries as rank/token/similarity dicts."""
        top_k = min(top_k, scores.numel())  # topk raises if k exceeds the vocabulary
        values, indices = torch.topk(scores, top_k)
        return [
            {
                "rank": rank,
                "token": self.tokenizer.decode([index.item()]),
                "similarity": value.item(),
            }
            for rank, (index, value) in enumerate(zip(indices, values), start=1)
        ]

    def identify_closest_tokens(self, embedding, top_k=5):
        """
        Find the tokens whose rows of self.g have the highest cosine similarity to an embedding

        Args:
            embedding: 1-D embedding vector (tensor or array-like) to compare against
            top_k: Number of tokens to return (clamped to the vocabulary size)

        Returns:
            List of dicts with "rank" (1-based), "token" and "similarity" (cosine)
        """
        with torch.no_grad():
            word_embeddings = self.g
            embedding = self._as_g_tensor(embedding)
            embedding_norm = embedding / embedding.norm()
            word_embeddings_norm = word_embeddings / word_embeddings.norm(dim=1, keepdim=True)
            similarities = word_embeddings_norm @ embedding_norm
        return self._top_k_tokens(similarities, top_k)

    def identify_closest_tokens_distance(self, embedding, top_k=5):
        """
        Find the tokens whose rows of self.g have the largest raw inner product with an embedding

        Unlike identify_closest_tokens, no normalization is applied, so row
        norms affect the ranking. Despite the method name, this is a dot-product
        score (larger = closer), not a Euclidean distance.

        Args:
            embedding: 1-D embedding vector (tensor or array-like) to compare against
            top_k: Number of tokens to return (clamped to the vocabulary size)

        Returns:
            List of dicts with "rank" (1-based), "token" and "similarity" (inner product)
        """
        with torch.no_grad():
            similarities = self.g @ self._as_g_tensor(embedding)
        return self._top_k_tokens(similarities, top_k)

# Example usage
if __name__ == "__main__":
    # Example run on the prompt from the PROMPT environment variable
    test_prompt = PROMPT
    predictor = GemmaEmbeddingPredictor(mlp_model_path=MODEL_PATH)
    top_k = 10

    # Get predicted embedding
    predicted_embedding = predictor.predict_next_token_embedding(test_prompt)
    print(f"Predicted embedding shape: {predicted_embedding.shape}")

    # Compare with the actual expected embedding
    comparison = predictor.compare_with_actual(test_prompt)
    print(f"Cosine similarity: {comparison['cosine_similarity']:.4f}")

    def _print_ranking(title, ranking):
        """Print a titled token ranking as returned by identify_closest_tokens*."""
        print(f"\n{title}:")
        for token_info in ranking:
            print(f"Rank {token_info['rank']}: '{token_info['token']}' (similarity: {token_info['similarity']:.4f})")

    predicted = torch.as_tensor(comparison['predicted_embedding'], dtype=torch.float32, device=predictor.device)
    actual = torch.as_tensor(comparison['actual_embedding'], dtype=torch.float32, device=predictor.device)

    _print_ranking(
        f"Top {top_k} tokens by cosine similarity to the predicted embedding",
        predictor.identify_closest_tokens(predicted, top_k=top_k),
    )
    _print_ranking(
        f"Top {top_k} tokens by cosine similarity to the actual expected embedding",
        predictor.identify_closest_tokens(actual, top_k=top_k),
    )
    _print_ranking(
        f"Top {top_k} tokens by inner product with the actual expected embedding",
        predictor.identify_closest_tokens_distance(actual, top_k=top_k),
    )